# -*- coding: utf-8 -*-
"""
author:LTH
date:
"""

import torch

from config import GetConfig
from datalist import train_dataset

# Number of target classes, taken from the training dataset's `classes`
# attribute (presumably a sequence of class names, e.g. as on torchvision's
# ImageFolder — confirm in datalist.py).
class_num = len(train_dataset.classes)

# Project configuration object, constructed once at import time.
args = GetConfig()


def one_hot(label, num_classes=None):
    """Encode integer class labels as a one-hot float tensor on the CPU.

    Args:
        label: Tensor of integer class indices on any device. It is
            flattened, so both ``(N,)`` and ``(N, 1)`` shapes work. The
            caller's tensor is not modified.
        num_classes: Width of the encoding. Defaults to the module-level
            ``class_num`` (number of classes in the training dataset).

    Returns:
        A float32 CPU tensor of shape ``(N, num_classes)``. Row ``i`` is 1 at
        column ``label[i]`` and 0 elsewhere.

    Raises:
        RuntimeError: from ``scatter_`` if a label is negative or
            ``>= num_classes``.
    """
    if num_classes is None:
        num_classes = class_num
    # Size the output from the labels themselves rather than a fixed batch
    # size, so a short final batch works. Use reshape (not resize_) so the
    # caller's tensor is left untouched and no uninitialised memory is read.
    # scatter_ requires int64 indices, hence .long().
    index = label.reshape(-1, 1).cpu().long()
    onehot = torch.zeros(index.size(0), num_classes)
    onehot.scatter_(1, index, 1)
    return onehot
